1 Preparations

1.1 Notation

We use the following notation for the datasets in use. Note, however, that some of the functions we are going to use take \(W\) as the treatment indicator.

  • \(D\) Treatment indicator (binary or multiarm)
  • \(y\) outcome
  • \(X\) features/controls

1.2 Seed

We clear the working directory (not strictly necessary) and specify the seed. Note that we will also set the seed before calling some functions such as the causal_forest() to ensure replicability.

# Clear the workspace (removes objects only, not loaded packages; generally
# discouraged in scripts, but kept here as a convenience for interactive use)
rm(list=ls())
# Store the seed in a variable so it can be re-set before stochastic calls
# (e.g. causal_forest()) to ensure replicability
seed<-1909
set.seed(seed)

1.3 Libraries

Below we load the libraries that we are going to use.

# loading & modifying data
library("readr")         # to read the data
library("dplyr")         # to manipulate data
library("fastDummies")   # create dummies
# charts & tables
library("ggplot2")       # to create charts
library("patchwork")     # to combine charts
library("flextable")     # design tables
library("modelsummary")  # structure tables
library("kableExtra")    # design table
library("estimatr")
library("ggpubr")
# regression & analysis
library("fixest")        # high dimensional FE
library("skimr")         # skim the data
# machine learning
library("policytree")    # policy tree (Athey & Wager, 2021)
library("grf")           # causal forest
library("rsample")       # data splitting 
library("randomForest")  # Traditional Random Forests
library("mlr3")          # learners
library("mlr3learners")  # learners
library("gbm")           # Generalized Boosted Regression
library("DoubleML")      # Double ML

1.4 Load and prepare data

1.4.1 Load data

First we load the datasets and remove rows with missing values.

# load full dataset
# (tab-delimited; keep pre-2004 crashes and drop the imputed "imp*" columns)

df_repl<-read_delim("../data/FARS-data-full-sample.txt",delim = "\t")%>%
              filter(year<2004)%>%
              select(-starts_with("imp"))
# load small dataset (same filtering as above)
df_sel<-read_delim("../data/FARS-data-selection-sample.txt",delim = "\t")%>%
              filter(year<2004)%>%
              select(-starts_with("imp"))
# remove rows with missing cases (listwise deletion)
df_repl<-df_repl[complete.cases(df_repl), ]
df_sel<-df_sel[complete.cases(df_sel), ]

# print number of obs
print(paste('Number of observations in the data:',nrow(df_repl),' (full sample);',nrow(df_sel), ' (selected/causal sample)'))
## [1] "Number of observations in the data: 38455  (full sample); 10328  (selected/causal sample)"

1.4.2 Manipulate data

The following block prepares the data. It is bigger than it needs to be and creates several data frames and matrices that we are not using currently, but that we might use at some point.

# Treatment indicators:
#   D       four-level factor: NONE / Lapbelt / LapShoulderSeat / Childseat
#   Dbinary 1 if any restraint (belt or child seat) was used, 0 otherwise
df_repl<-df_repl%>%mutate(D=case_when(lapshould==1~"LapShoulderSeat",lapbelt==1~"Lapbelt",
                                      childseat==1~"Childseat",TRUE~"NONE"),
                          D=factor(D,levels=c("NONE","Lapbelt","LapShoulderSeat","Childseat")),
                          Dbinary=case_when(lapshould==1~1,lapbelt==1~1,childseat==1~1,TRUE~0))
df_sel <-df_sel %>%mutate(D=case_when(lapshould==1~"LapShoulderSeat",lapbelt==1~"Lapbelt",
                                    childseat==1~"Childseat",TRUE~"NONE"),
                         D=factor(D,levels=c("NONE","Lapbelt","LapShoulderSeat","Childseat")),
                         Dbinary=case_when(lapshould==1~1,lapbelt==1~1,childseat==1~1,TRUE~0))
# Convert categorical variables to indicator (dummy) columns; drop the raw
# categorical columns and the dummies generated from D itself
df_repl<-dummy_cols(df_repl%>%select(-restraint))%>%select(-starts_with("D_"),-crashtm,-crashcar,-age,-vehicles1,-vehicles2)
df_sel<-dummy_cols(df_sel%>%select(-restraint))%>%select(-starts_with("D_"),-crashtm,-crashcar,-age,-vehicles1,-vehicles2)
# Training and test data: 50/50 split (seed reset for replicability)
set.seed(seed)
df_repl_split <- initial_split(df_repl, prop = .5)
df_repl_train <- training(df_repl_split)
df_repl_test  <- testing(df_repl_split)
df_sel_split <- initial_split(df_sel, prop = .5)
df_sel_train <- training(df_sel_split)
df_sel_test  <- testing(df_sel_split)
# X matrices: all features except the outcome and the treatment columns
X_repl_train<-as.matrix(df_repl_train%>%select(-death,-D,-Dbinary,-childseat,-lapbelt,-lapshould))
X_repl_test<- as.matrix(df_repl_test%>%select(-death,-D,-Dbinary,-childseat,-lapbelt,-lapshould))
X_repl<- as.matrix(df_repl%>%select(-death,-D,-Dbinary,-childseat,-lapbelt,-lapshould))
X_sel_train<- as.matrix(df_sel_train%>%select(-death,-D,-Dbinary,-childseat,-lapbelt,-lapshould))
X_sel_test<-  as.matrix(df_sel_test%>%select(-death,-D,-Dbinary,-childseat,-lapbelt,-lapshould))
X_sel<-  as.matrix(df_sel%>%select(-death,-D,-Dbinary,-childseat,-lapbelt,-lapshould))
# "No controls" versions: a single constant column (intercept only)
X_repl_train_nocontrols<-as.matrix(rep(1,nrow(X_repl_train)))
X_repl_test_nocontrols<- as.matrix(rep(1,nrow(X_repl_test)))
X_repl_nocontrols<-as.matrix(rep(1,nrow(X_repl)))
X_sel_train_nocontrols<- as.matrix(rep(1,nrow(X_sel_train)))
X_sel_test_nocontrols<-  as.matrix(rep(1,nrow(X_sel_test)))
X_sel_nocontrols<-as.matrix(rep(1,nrow(X_sel)))
# D matrices (treatment factors and binary indicators)
# BUGFIX: D_repl_test and D_sel_test previously used the *training* data
# frames (df_repl_train$D / df_sel_train$D), so the treatment vectors did not
# correspond to the test rows; they now use the test data frames.
D_repl_train<-factor(df_repl_train$D,levels=c("NONE","Lapbelt","LapShoulderSeat","Childseat"))
D_repl_test<-factor(df_repl_test$D,levels=c("NONE","Lapbelt","LapShoulderSeat","Childseat"))
D_repl<-factor(df_repl$D,levels=c("NONE","Lapbelt","LapShoulderSeat","Childseat"))
D_sel_train<-factor(df_sel_train$D,levels=c("NONE","Lapbelt","LapShoulderSeat","Childseat"))
D_sel_test<-factor(df_sel_test$D,levels=c("NONE","Lapbelt","LapShoulderSeat","Childseat"))
D_sel<-factor(df_sel$D,levels=c("NONE","Lapbelt","LapShoulderSeat","Childseat"))
D_binary_repl_train<-as.matrix(df_repl_train%>%select(Dbinary))
D_binary_repl_test<- as.matrix(df_repl_test%>%select(Dbinary))
D_binary_repl<- as.matrix(df_repl%>%select(Dbinary))
D_binary_sel_train<- as.matrix(df_sel_train%>%select(Dbinary))
D_binary_sel_test<-  as.matrix(df_sel_test%>%select(Dbinary))
D_binary_sel<-  as.matrix(df_sel%>%select(Dbinary))
# Y matrices (outcome: death indicator)
Y_repl_train<-as.matrix(df_repl_train%>%select(death))
Y_repl_test<- as.matrix(df_repl_test%>%select(death))
Y_repl<- as.matrix(df_repl%>%select(death))
Y_sel_train<- as.matrix(df_sel_train%>%select(death))
Y_sel_test<-  as.matrix(df_sel_test%>%select(death))
Y_sel<-  as.matrix(df_sel%>%select(death))

1.5 Summary statistics

Let’s calculate some summary statistics on the datasets. This is mostly done to make Alessandro jealous of the built-in stuff in R. However, as many variables are binary, it doesn’t make much sense to show the histograms.

# Variables to summarise
tmp <- df_sel %>%
  select(splmU55, thoulbs_I, modelyr, year, numcrash, weekend,
         lowviol, highviol, ruralrd, frimp, suv, death)
# Drop missings and standardise each column (inputs for the inline plots)
tmp_list <- lapply(tmp, function(col) scale(na.omit(col)))

# Placeholder column generator; the cells are filled with inline plots below
emptycol <- function(x) " "
datasummary(splmU55 + thoulbs_I + modelyr + year + numcrash + weekend +
              lowviol + highviol + ruralrd + frimp + suv + death ~
              Mean + SD +
              Heading("Boxplot") * emptycol + Heading("Histogram") * emptycol,
            data = tmp) %>%
  column_spec(column = 4, image = spec_boxplot(tmp_list)) %>%
  column_spec(column = 5, image = spec_hist(tmp_list))
Mean SD Boxplot Histogram
splmU55 0.88 0.33
thoulbs_I 2.45 1.54
modelyr 1987.09 8.30
year 1993.42 7.13
numcrash 6.63 4.51
weekend 0.40 0.49
lowviol 0.29 0.45
highviol 0.08 0.27
ruralrd 0.08 0.28
frimp 0.67 0.47
suv 0.10 0.29
death 0.04 0.20

2 Causal Forest out of the box

We first estimate a causal forest where we, more or less, use all default settings.

2.1 Estimate Forest

In the block below we estimate the causal forest. With tune.parameters = "all" we ask R to find the optimal parameter settings using cross-validation (tuning) on 50 forests with 200 trees. The AIPW ATE is -0.043 (0.007). Note that we get the warning that propensity scores are between 0.032 and 0.99.

# Fit a causal forest with cross-validated ("tuned") parameters; X = features,
# Y = death indicator, W = binary any-restraint treatment
cfbinary<- causal_forest(X=X_sel,Y=Y_sel, W=D_binary_sel,tune.parameters = "all")
# Doubly robust (AIPW) average treatment effect with standard error
average_treatment_effect(cfbinary)
##     estimate      std.err 
## -0.043050487  0.007757556

2.2 Tuned parameter settings

As the block below shows, the tuning did not lead to parameter settings that differ from the default. We will return to that later, but it is worth noting that the minimum node (and thereby leaf) size is 5, which implies that the leaves can be quite small and we should worry about overfitting.

cfbinary$tuning.output               
## Tuning status: default.
## This indicates tuning was attempted. However, we could not find parameters that were expected to perform better than default: 
## 
## sample.fraction: 0.5
##  mtry: 26
##  min.node.size: 5
##  honesty.fraction: 0.5
##  honesty.prune.leaves: TRUE
##  alpha: 0.05
##  imbalance.penalty: 0

2.3 Plot tree

In the block below we plot one of the trees. This is not super helpful because it becomes so big that we can’t read it. However, it does give some suggestion on shallowness which we cannot set here. We observe that the tree is quite deep.

# Extract the first tree of the fitted forest (for illustration only)
treeex1<-get_tree(cfbinary,1)
# Plot the tree
plot(treeex1)

Causal Tree Illustration

2.4 Omnibus test

Next we run the diagnostic test by running a regression of the CATE on the mean and the predicted deviation. If the forest captures the mean and the heterogeneity well, both coefficients should be 1. We are not far off!

test_calibration(cfbinary)
## 
## Best linear fit using forest predictions (on held-out data)
## as well as the mean forest prediction as regressors, along
## with one-sided heteroskedasticity-robust (HC3) SEs:
## 
##                                Estimate Std. Error t value    Pr(>t)    
## mean.forest.prediction          0.94116    0.12218  7.7032 7.243e-15 ***
## differential.forest.prediction  1.13675    0.28226  4.0273 2.841e-05 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

2.5 Influential features

Below we plot the features by how often they were used to create a split. The speed limit is often chosen, followed by weight.

# Split-frequency based importance of each feature in the fitted forest
importance <- variable_importance(cfbinary)

# Pair each importance value with its feature name, then plot as a
# horizontal bar chart sorted by importance
var_imp <- data.frame(importance = importance, names = colnames(X_sel_train))
ggplot(var_imp, aes(x = reorder(names, importance), y = importance)) +
  geom_col(fill = "#f56c42", color = "white") +
  coord_flip() +
  theme_minimal() +
  theme(axis.text.x = element_text(angle = 45, vjust = 1, hjust = 1)) +
  labs(x = " ")

2.6 Cate distribution

Below we show the distribution of the CATEs.

# get out-of-bag CATE predictions for every observation
cate<-data.frame(sample="CATEs",tau=predict(cfbinary)$predictions)
# histogram of all CATEs (y rescaled to the share of observations per bin)
ggplot(cate,aes(x=tau))+
   geom_histogram(aes(y=..count../sum(..count..)),bins=100,alpha=0.95, position = "identity",
                  fill="#f56c42",color="white",size=.2)+
  theme_minimal()+
  labs(title=" ",x="Conditional Average Treatment Effect",y="Density")

2.7 Plot quartiles of CATEs

We now split the sample in four quartiles according to the size of the CATEs.

# Split sample into 4 groups (quartiles) based on the predicted CATEs
df_sel["categroup"] <- factor(ntile(predict(cfbinary)$predictions, n=4))
# calculate the AIPW ATE separately within each quartile
estimated_aipw_ate <- lapply(
  seq(4), function(w) {
  ate <- average_treatment_effect(cfbinary, subset = df_sel$categroup == w,method = "AIPW")
})
# Combine the per-quartile estimates into a data frame
estimated_aipw_ate <- data.frame(do.call(rbind, estimated_aipw_ate))
estimated_aipw_ate$Ntile <- as.numeric(rownames(estimated_aipw_ate))
estimated_aipw_ate$type<-"AIPW"
# Mean of the raw CATEs per quartile (no standard error available)
df_sel["cate"]<-predict(cfbinary)$predictions
cates<-df_sel%>%group_by(categroup)%>%summarise(estimate=mean(cate))%>%rename(Ntile=categroup)%>%mutate(std.err=NA,type="CATE")

dfplot<-rbind(estimated_aipw_ate,cates)
# create plot: point estimates with 95% confidence intervals, by quartile
ggplot(dfplot,aes(color=type)) +
  geom_pointrange(aes(x = Ntile, y = estimate, ymax = estimate + 1.96 * `std.err`, ymin = estimate - 1.96 * `std.err`), 
                  size = 1,
                  position = position_dodge(width = .5)) +
  theme_minimal() +theme(legend.position = "top")+
  geom_hline(yintercept=0,linetype="dashed")+
  labs(color="",x = "Quartile", y = "AIPW ATE", title = "AIPW ATEs by  quartiles of the conditional average treatment effect")

Below we compare the characteristics of the first and fourth quartile above.

# create table comparing covariate means in the first vs fourth CATE quartile
# (note: the filtered data is also stored in `sumstatdata` as a side effect
# of the assignment inside the `data =` argument)
datasummary_balance(~categroup,
                    data = sumstatdata<-df_sel%>%filter(categroup%in%c(1,4))%>%
                      mutate(categroup=ifelse(categroup==1,1,4))%>%select(-D),
                    title = "Comparison of the first vs fourth quartile",
                    fmt= '%.3f',
                    dinm_statistic = "p.value")
Comparison of the first vs fourth quartile
1
4
Mean Std. Dev. Mean Std. Dev. Diff. in Means p
year 1996.266 4.937 1990.663 8.963 −5.604 0.000
passgcar 0.712 0.453 0.545 0.498 −0.168 0.000
suv 0.073 0.261 0.146 0.353 0.072 0.000
weekend 0.423 0.494 0.382 0.486 −0.041 0.003
frimp 0.636 0.481 0.750 0.433 0.114 0.000
indfrimp 0.185 0.388 0.134 0.341 −0.051 0.000
rearimp 0.050 0.217 0.029 0.169 −0.020 0.000
indrearimp 0.049 0.216 0.030 0.171 −0.019 0.000
rsimp 0.014 0.116 0.009 0.094 −0.005 0.113
lsimp 0.035 0.184 0.031 0.172 −0.005 0.349
death 0.074 0.262 0.024 0.154 −0.050 0.000
modelyr 1990.529 6.165 1984.908 10.129 −5.622 0.000
childseat 0.301 0.459 0.227 0.419 −0.075 0.000
lapbelt 0.162 0.368 0.129 0.335 −0.033 0.001
lapshould 0.250 0.433 0.172 0.377 −0.078 0.000
ruralrd 0.041 0.198 0.122 0.327 0.081 0.000
row1 0.297 0.457 0.315 0.465 0.018 0.165
backright 0.251 0.433 0.249 0.433 −0.002 0.898
backleft 0.281 0.449 0.222 0.415 −0.059 0.000
backother 0.019 0.138 0.032 0.176 0.013 0.004
male 0.507 0.500 0.565 0.496 0.058 0.000
missweight 0.216 0.411 0.261 0.439 0.045 0.000
thoulbs_I 2.209 1.288 2.817 1.776 0.608 0.000
numcrash 7.058 6.930 6.465 2.288 −0.593 0.000
drivebelt 0.743 0.437 0.577 0.494 −0.166 0.000
splmU55 0.586 0.493 1.000 0.000 0.414 0.000
lowviol 0.301 0.459 0.239 0.427 −0.061 0.000
highviol 0.088 0.283 0.043 0.203 −0.045 0.000
Dbinary 0.713 0.452 0.527 0.499 −0.186 0.000
crashtm_1_day 0.761 0.426 0.918 0.274 0.157 0.000
crashtm_2_night 0.204 0.403 0.072 0.259 −0.132 0.000
crashtm_3_morn 0.034 0.181 0.009 0.096 −0.025 0.000
cate -0.091 0.019 -0.017 0.011 0.073 0.000

2.8 Mean CATEs by covariates

Below we plot the average CATE by speed limit and weight.

# Out-of-bag CATE for every observation
df_sel["tau"]<-predict(cfbinary)$predictions
# Panel 1: mean CATE by model year, split by speed limit (splmU55)
df_sel_train_col<-df_sel%>%
  group_by(modelyr,splmU55)%>%
  summarise(tau=mean(tau))
p1<-ggplot(df_sel_train_col,aes(x=modelyr,y=tau,color=as.factor(splmU55)))+geom_point()+
  ylim(-0.125,0)+theme_classic()
# Panel 2: mean CATE by crash year, split by speed limit
df_sel_train_col<-df_sel%>%
  group_by(year,splmU55)%>%
  summarise(tau=mean(tau))
p2<-ggplot(df_sel_train_col,aes(x=year,y=tau,color=as.factor(splmU55)))+geom_point()+
  ylim(-0.125,0)+labs(y="")+theme_classic()+theme(axis.text.y=element_blank())
# Panel 3: mean CATE by vehicle weight; bin weight into 50 quantile groups
df_sel_train_col<-df_sel
df_sel_train_col["xtile"] <- ntile(as.numeric(df_sel_train_col$thoulbs_I), n=50)
df_sel_train_col<-df_sel_train_col%>%
  group_by(xtile)%>%
  summarise(thoulbs_I=mean(thoulbs_I),tau=mean(tau))%>%filter(thoulbs_I!=0)
p3<-ggplot(df_sel_train_col,aes(x=thoulbs_I,y=tau))+geom_point()+
  ylim(-0.125,0)+labs(y="")+theme_classic()+theme(axis.text.y=element_blank())
# Combine the three panels with a shared legend
ggarrange(p1, p2, p3, ncol=3, nrow=1, common.legend = TRUE, legend="bottom")

2.9 CATE distribution by speed limit

And finally we plot the distribution of the CATEs by speed limit.

# get CATE predictions together with the speed-limit indicator
cate<-data.frame(sample="CATEs",splmU55=df_sel$splmU55,tau=predict(cfbinary)$predictions)
# histogram of CATEs, coloured by speed limit group
ggplot(cate,aes(x=tau,fill=as.factor(splmU55),group=splmU55))+
   geom_histogram(aes(y=..count../sum(..count..)),bins=100,alpha=0.75, position = "identity",
                 color="white",size=0.1)+
  theme_minimal()+theme(legend.position="top")+
  labs(title=" ",x="Conditional Average Treatment Effect",y="Density", fill="splmU55")

2.10 Policy learning

Now to policy learning. We want to find the optimal policy. However, given that the predicted CATEs indicate a benefit (a lower death probability) for almost everyone, it is expected that the unconditional optimal policy is close to treating everyone.

2.10.1 Features to consider

We only select a subset of features to base our policy on. We will for example not consider the type of accident because it doesn’t seem sensible to create a policy saying that you should have seatbelts if you were going to have a specific type of accident.

##  [1] "passgcar"        "suv"             "weekend"         "modelyr"        
##  [5] "ruralrd"         "row1"            "backright"       "backleft"       
##  [9] "backother"       "male"            "missweight"      "thoulbs_I"      
## [13] "numcrash"        "drivebelt"       "splmU55"         "lowviol"        
## [17] "highviol"        "crashtm_1_day"   "crashtm_2_night" "crashtm_3_morn"

2.10.2 Doubly robust scores

Next we compute the doubly robust scores. These scores are used to feed into the algorithm.

2.10.3 Find optimal policy

We now feed the scores to the algorithm that maximizes the difference between treating those directed by the features vs treating a random sample of the population. We set the depth to 2, meaning that the complexity of the policy is rather limited.

## policy_tree object 
## Tree depth:  2 
## Actions:  1:  2: Dbinary 
## Variable splits: 
## (1) split_variable: thoulbs_I  split_value: 1.939 
##   (2) split_variable: thoulbs_I  split_value: 1.869 
##     (4) * action: 2 
##     (5) * action: 1 
##   (3) split_variable: backother  split_value: 0 
##     (6) * action: 2 
##     (7) * action: 1

2.10.4 Plot optimal policy

Below we simply plot the policy. Action 2 is treating. Action 1 is not treating.

2.10.5 Fraction treated

How many are treated?

# Predicted action per row (1 = do not treat, 2 = treat); opt.tree and X_pol
# are created in the policy-learning steps above (code not shown in this chunk)
df_sel["policy"]<-predict(opt.tree, df_sel%>%select(colnames(X_pol)))
# recode actions to 0/1 so the mean equals the fraction treated
df_sel<-df_sel%>%mutate(policy=policy-1)
mean(df_sel$policy)
## [1] 0.9772463

We observe that almost everyone is treated, as expected.

2.10.6 Advantage of policy

What is the benefit for those treated compared to those not treated?

# Mean doubly robust score among the treated vs the untreated under a
# candidate policy, with standard errors.
#
# @param policy Numeric 0/1 vector: 1 = treated under the policy.
# @param scores Vector of doubly robust scores; defaults to the global
#   Gamma.dr computed above (kept as default for backward compatibility).
# @return Named vector c(att, attse, atu, atuse).
get_advantage <- function(policy, scores = Gamma.dr) {
    # (the unused `benefits` intermediate from the original was removed)

    # Treated under the policy
    ATT <- mean(scores[policy == 1])
    ATTse <- sqrt(var(scores[policy == 1]) / length(scores[policy == 1]))
    # Untreated under the policy
    ATU <- mean(scores[policy == 0])
    ATUse <- sqrt(var(scores[policy == 0]) / length(scores[policy == 0]))
    # output
    c(att = ATT, attse = ATTse, atu = ATU, atuse = ATUse)
}

get_advantage(df_sel$policy)
##          att        attse          atu        atuse 
## -0.050160851  0.006408996  0.262332109  0.200576933

2.11 Policy learning with a cost

We observe above that we treated basically everyone, as expected. But that also assumed that the policy was free. What if it was costly? We include a cost by simply reducing the benefit of treatment.

2.11.1 Doubly robust scores with costs

We add a cost corresponding to reducing the benefit by 0.04, which corresponds to a pretty high cost.

2.11.2 Find optimal policy

We now feed the scores to the algorithm that maximizes the difference between treating those directed by the features vs treating a random sample of the population. We set the depth to 2, meaning that the complexity of the policy is rather limited.

## policy_tree object 
## Tree depth:  2 
## Actions:  1:  2: Dbinary 
## Variable splits: 
## (1) split_variable: thoulbs_I  split_value: 1.962 
##   (2) split_variable: splmU55  split_value: 0 
##     (4) * action: 2 
##     (5) * action: 1 
##   (3) split_variable: thoulbs_I  split_value: 3.281 
##     (6) * action: 2 
##     (7) * action: 1

2.11.3 Plot optimal policy

Below we simply plot the policy. Action 2 is treating. Action 1 is not treating.

2.11.4 Fraction treated

# Predicted action per row under the costly policy (1 = do not treat, 2 = treat)
df_sel["policy"]<-predict(opt.tree, df_sel%>%select(colnames(X_pol)))
# recode actions to 0/1 so the mean equals the fraction treated
df_sel<-df_sel%>%mutate(policy=policy-1)
mean(df_sel$policy)
## [1] 0.4469404

We now only treat half of the sample!

2.11.5 Advantage of policy

get_advantage(df_sel$policy)
##         att       attse         atu       atuse 
## -0.08364524  0.01231361 -0.01024492  0.00986525

And those that we treat now benefit a lot. But those that we don’t treat would also benefit somewhat.

# get CATE predictions together with the 0/1 policy assignment
cate<-data.frame(sample="CATEs",treated=df_sel$policy,tau=predict(cfbinary)$predictions)
# histogram of CATEs, coloured by whether the policy treats the observation
ggplot(cate,aes(x=tau,fill=as.factor(treated),group=treated))+
   geom_histogram(aes(y=..count../sum(..count..)),bins=100,alpha=0.75, position = "identity",
                 color="white",size=0.1)+
  theme_minimal()+theme(legend.position="top")+
  labs(title=" ",x="Conditional Average Treatment Effect",y="Density", fill="Treated")

3 CF with adhoc settings

3.1 Forest with ad hoc settings

We set the minimum node size 50 to avoid overfitting.

# Re-fit the causal forest with a larger minimum node size (50 instead of
# the default 5) to limit tree complexity and reduce overfitting risk
cfbinary<- causal_forest(X=X_sel,Y=Y_sel, W=D_binary_sel,min.node.size=50)
# Doubly robust (AIPW) average treatment effect with standard error
average_treatment_effect(cfbinary)
##     estimate      std.err 
## -0.043945503  0.007557879

3.2 Plot tree

We can see that increasing the minimum node size from 5 to 50 reduced the complexity of the tree considerably.

# Extract the first tree of the re-fitted forest (for illustration only)
treeex1<-get_tree(cfbinary,1)
# Plot the tree
plot(treeex1)

Causal Tree Illustration

3.3 Omnibus test

The omnibus test is still pretty good.

test_calibration(cfbinary)
## 
## Best linear fit using forest predictions (on held-out data)
## as well as the mean forest prediction as regressors, along
## with one-sided heteroskedasticity-robust (HC3) SEs:
## 
##                                Estimate Std. Error t value    Pr(>t)    
## mean.forest.prediction          0.97676    0.12214  7.9972 7.044e-16 ***
## differential.forest.prediction  1.27208    0.55395  2.2964   0.01084 *  
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

3.4 Influential features

The speed limit is still number 1…

# Get split-frequency based feature importance from the re-fitted forest
importance=variable_importance(cfbinary)

# Horizontal bar chart of features sorted by importance
var_imp <- data.frame(importance=importance,names=colnames(X_sel_train))
ggplot(var_imp,aes(x= reorder(names,importance),y=importance))+
  geom_bar(stat="identity",fill="#f56c42",color="white")+
  theme_minimal()+
  theme(axis.text.x = element_text(angle=45,vjust = 1, hjust=1))+
  labs(x=" ")+
  coord_flip()

3.5 Cate distribution

The CATE distribution is slightly different.

# get out-of-bag CATE predictions from the re-fitted forest
cate<-data.frame(sample="CATEs",tau=predict(cfbinary)$predictions)
# histogram of all CATEs (y rescaled to the share of observations per bin)
ggplot(cate,aes(x=tau))+
   geom_histogram(aes(y=..count../sum(..count..)),bins=100,alpha=0.95, position = "identity",
                  fill="#f56c42",color="white",size=.2)+
  theme_minimal()+
  labs(title=" ",x="Conditional Average Treatment Effect",y="Density")

3.6 Plot quartiles of CATEs

# Split sample into 4 groups (quartiles) based on the predicted CATEs
df_sel["categroup"] <- factor(ntile(predict(cfbinary)$predictions, n=4))
# calculate the AIPW ATE separately within each quartile
estimated_aipw_ate <- lapply(
  seq(4), function(w) {
  ate <- average_treatment_effect(cfbinary, subset = df_sel$categroup == w,method = "AIPW")
})
# Combine the per-quartile estimates into a data frame
estimated_aipw_ate <- data.frame(do.call(rbind, estimated_aipw_ate))
estimated_aipw_ate$Ntile <- as.numeric(rownames(estimated_aipw_ate))
estimated_aipw_ate$type<-"AIPW"
# Mean of the raw CATEs per quartile (no standard error available)
df_sel["cate"]<-predict(cfbinary)$predictions
cates<-df_sel%>%group_by(categroup)%>%summarise(estimate=mean(cate))%>%rename(Ntile=categroup)%>%mutate(std.err=NA,type="CATE")

dfplot<-rbind(estimated_aipw_ate,cates)
# create plot: point estimates with 95% confidence intervals, by quartile
ggplot(dfplot,aes(color=type)) +
  geom_pointrange(aes(x = Ntile, y = estimate, ymax = estimate + 1.96 * `std.err`, ymin = estimate - 1.96 * `std.err`), 
                  size = 1,
                  position = position_dodge(width = .5)) +
  theme_minimal() +theme(legend.position = "top")+
  geom_hline(yintercept=0,linetype="dashed")+
  labs(color="",x = "Quartile", y = "AIPW ATE", title = "AIPW ATEs by  quartiles of the conditional average treatment effect")

# create table comparing covariate means in the first vs fourth CATE quartile
# (the filtered data is also stored in `sumstatdata` via the assignment
# inside the `data =` argument)
datasummary_balance(~categroup,
                    data = sumstatdata<-df_sel%>%filter(categroup%in%c(1,4))%>%
                      mutate(categroup=ifelse(categroup==1,1,4))%>%select(-D),
                    title = "Comparison of the first vs fourth quartile",
                    fmt= '%.3f',
                    dinm_statistic = "p.value")
Comparison of the first vs fourth quartile
1
4
Mean Std. Dev. Mean Std. Dev. Diff. in Means p
year 1997.018 4.384 1988.254 8.256 −8.764 0.000
passgcar 0.741 0.438 0.744 0.437 0.003 0.775
suv 0.069 0.253 0.103 0.305 0.035 0.000
weekend 0.415 0.493 0.364 0.481 −0.051 0.000
frimp 0.629 0.483 0.730 0.444 0.101 0.000
indfrimp 0.182 0.386 0.158 0.364 −0.024 0.022
rearimp 0.055 0.229 0.029 0.167 −0.027 0.000
indrearimp 0.053 0.225 0.027 0.162 −0.026 0.000
rsimp 0.014 0.116 0.011 0.105 −0.002 0.451
lsimp 0.033 0.179 0.029 0.168 −0.004 0.379
death 0.070 0.255 0.028 0.165 −0.042 0.000
modelyr 1991.363 5.695 1981.364 9.127 −9.998 0.000
childseat 0.322 0.467 0.186 0.389 −0.136 0.000
lapbelt 0.173 0.378 0.127 0.334 −0.046 0.000
lapshould 0.276 0.447 0.124 0.330 −0.152 0.000
ruralrd 0.034 0.180 0.133 0.339 0.099 0.000
row1 0.247 0.431 0.328 0.470 0.081 0.000
backright 0.285 0.452 0.227 0.419 −0.058 0.000
backleft 0.285 0.451 0.231 0.422 −0.053 0.000
backother 0.023 0.149 0.026 0.160 0.003 0.419
male 0.504 0.500 0.534 0.499 0.030 0.030
missweight 0.166 0.372 0.058 0.234 −0.108 0.000
thoulbs_I 2.348 1.199 3.654 1.067 1.306 0.000
numcrash 7.262 7.061 6.580 2.873 −0.682 0.000
drivebelt 0.864 0.343 0.500 0.500 −0.364 0.000
splmU55 0.512 0.500 1.000 0.000 0.488 0.000
lowviol 0.295 0.456 0.271 0.445 −0.024 0.055
highviol 0.072 0.259 0.069 0.253 −0.003 0.624
Dbinary 0.771 0.420 0.437 0.496 −0.334 0.000
crashtm_1_day 0.789 0.408 0.905 0.294 0.115 0.000
crashtm_2_night 0.177 0.381 0.078 0.269 −0.098 0.000
crashtm_3_morn 0.034 0.181 0.017 0.129 −0.017 0.000
cate -0.071 0.007 -0.035 0.006 0.036 0.000
tau -0.086 0.022 -0.023 0.016 0.064 0.000
policy 0.844 0.363 0.099 0.299 −0.745 0.000

3.7 Mean CATEs by covariates

# Out-of-bag CATE for every observation (re-fitted forest)
df_sel["tau"]<-predict(cfbinary)$predictions
# Panel 1: mean CATE by model year, split by speed limit (splmU55)
df_sel_train_col<-df_sel%>%
  group_by(modelyr,splmU55)%>%
  summarise(tau=mean(tau))
p1<-ggplot(df_sel_train_col,aes(x=modelyr,y=tau,color=as.factor(splmU55)))+geom_point()+
  ylim(-0.125,0)+theme_classic()
# Panel 2: mean CATE by crash year, split by speed limit
df_sel_train_col<-df_sel%>%
  group_by(year,splmU55)%>%
  summarise(tau=mean(tau))
p2<-ggplot(df_sel_train_col,aes(x=year,y=tau,color=as.factor(splmU55)))+geom_point()+
  ylim(-0.125,0)+labs(y="")+theme_classic()+theme(axis.text.y=element_blank())
# Panel 3: mean CATE by vehicle weight; bin weight into 50 quantile groups
df_sel_train_col<-df_sel
df_sel_train_col["xtile"] <- ntile(as.numeric(df_sel_train_col$thoulbs_I), n=50)
df_sel_train_col<-df_sel_train_col%>%
  group_by(xtile)%>%
  summarise(thoulbs_I=mean(thoulbs_I),tau=mean(tau))%>%filter(thoulbs_I!=0)
p3<-ggplot(df_sel_train_col,aes(x=thoulbs_I,y=tau))+geom_point()+
  ylim(-0.125,0)+labs(y="")+theme_classic()+theme(axis.text.y=element_blank())
# Combine the three panels with a shared legend
ggarrange(p1, p2, p3, ncol=3, nrow=1, common.legend = TRUE, legend="bottom")

3.8 CATE distribution by speed limit

# get CATE predictions together with the speed-limit indicator
cate<-data.frame(sample="CATEs",splmU55=df_sel$splmU55,tau=predict(cfbinary)$predictions)
# histogram of CATEs, coloured by speed limit group
ggplot(cate,aes(x=tau,fill=as.factor(splmU55),group=splmU55))+
   geom_histogram(aes(y=..count../sum(..count..)),bins=100,alpha=0.75, position = "identity",
                 color="white",size=0.1)+
  theme_minimal()+theme(legend.position="top")+
  labs(title=" ",x="Conditional Average Treatment Effect",y="Density", fill="splmU55")

3.9 Policy learning

Now to policy learning. We want to find the optimal policy. However, given that the predicted CATEs indicate a benefit (a lower death probability) for almost everyone, it is expected that the unconditional optimal policy is close to treating everyone.

3.9.1 Features to consider

We only select a subset of features to base our policy on. We will for example not consider the type of accident because it doesn’t seem sensible to create a policy saying that you should have seatbelts if you were going to have a specific type of accident.

##  [1] "passgcar"        "suv"             "weekend"         "modelyr"        
##  [5] "ruralrd"         "row1"            "backright"       "backleft"       
##  [9] "backother"       "male"            "missweight"      "thoulbs_I"      
## [13] "numcrash"        "drivebelt"       "splmU55"         "lowviol"        
## [17] "highviol"        "crashtm_1_day"   "crashtm_2_night" "crashtm_3_morn"

3.9.2 Doubly robust scores

Next we compute the doubly robust scores. These scores are used to feed into the algorithm.

3.9.3 Find optimal policy

We now feed the scores to the algorithm that maximizes the difference between treating those directed by the features vs treating a random sample of the population. We set the depth to 2, meaning that the complexity of the policy is rather limited.

## policy_tree object 
## Tree depth:  2 
## Actions:  1:  2: Dbinary 
## Variable splits: 
## (1) split_variable: thoulbs_I  split_value: 1.939 
##   (2) split_variable: thoulbs_I  split_value: 1.869 
##     (4) * action: 2 
##     (5) * action: 1 
##   (3) split_variable: backother  split_value: 0 
##     (6) * action: 2 
##     (7) * action: 1

3.9.4 Plot optimal policy

3.9.5 Fraction treated

How many are treated?

# Predicted action per row (1 = do not treat, 2 = treat); opt.tree and X_pol
# are created in the policy-learning steps above (code not shown in this chunk)
df_sel["policy"]<-predict(opt.tree, df_sel%>%select(colnames(X_pol)))
# recode actions to 0/1 so the mean equals the fraction treated
df_sel<-df_sel%>%mutate(policy=policy-1)
mean(df_sel$policy)
## [1] 0.9772463

We observe that almost everyone is treated, as expected.

3.9.6 Advantage of policy

What is the benefit for those treated compared to those not treated?

# Mean doubly robust score among the treated vs the untreated under a
# candidate policy, with standard errors.
#
# @param policy Numeric 0/1 vector: 1 = treated under the policy.
# @param scores Vector of doubly robust scores; defaults to the global
#   Gamma.dr computed above (kept as default for backward compatibility).
# @return Named vector c(att, attse, atu, atuse).
get_advantage <- function(policy, scores = Gamma.dr) {
    # (the unused `benefits` intermediate from the original was removed)

    # Treated under the policy
    ATT <- mean(scores[policy == 1])
    ATTse <- sqrt(var(scores[policy == 1]) / length(scores[policy == 1]))
    # Untreated under the policy
    ATU <- mean(scores[policy == 0])
    ATUse <- sqrt(var(scores[policy == 0]) / length(scores[policy == 0]))
    # output
    c(att = ATT, attse = ATTse, atu = ATU, atuse = ATUse)
}

get_advantage(df_sel$policy)
##          att        attse          atu        atuse 
## -0.050601107  0.006550859  0.241905622  0.175910798

3.10 Policy learning with a cost

We observe above that we treated basically everyone, as expected. But that also assumed that the policy was free. What if it was costly? We include a cost by simply reducing the benefit of treatment.

3.10.1 Doubly robust scores with costs

We add a cost corresponding to reducing the benefit by 0.04, which corresponds to a pretty high cost.

3.10.2 Find optimal policy

## policy_tree object 
## Tree depth:  2 
## Actions:  1:  2: Dbinary 
## Variable splits: 
## (1) split_variable: thoulbs_I  split_value: 1.962 
##   (2) split_variable: splmU55  split_value: 0 
##     (4) * action: 2 
##     (5) * action: 1 
##   (3) split_variable: thoulbs_I  split_value: 3.065 
##     (6) * action: 2 
##     (7) * action: 1

3.10.3 Plot optimal policy

Below we simply plot the policy. Action 2 is treating. Action 1 is not treating.

3.10.4 Fraction treated

# Predicted action per row under the costly policy (1 = do not treat, 2 = treat)
df_sel["policy"]<-predict(opt.tree, df_sel%>%select(colnames(X_pol)))
# recode actions to 0/1 so the mean equals the fraction treated
df_sel<-df_sel%>%mutate(policy=policy-1)
mean(df_sel$policy)
## [1] 0.3384973

We now only treat about a third of the sample!

3.10.5 Advantage of policy

get_advantage(df_sel$policy)
##          att        attse          atu        atuse 
## -0.096897108  0.015133415 -0.016849658  0.008382543

And those that we treat now benefit a lot. But those that we don’t treat would also benefit somewhat.

# get CATE predictions together with the 0/1 policy assignment
cate<-data.frame(sample="CATEs",treated=df_sel$policy,tau=predict(cfbinary)$predictions)
# histogram of CATEs, coloured by whether the policy treats the observation
ggplot(cate,aes(x=tau,fill=as.factor(treated),group=treated))+
   geom_histogram(aes(y=..count../sum(..count..)),bins=100,alpha=0.75, position = "identity",
                 color="white",size=0.1)+
  theme_minimal()+theme(legend.position="top")+
  labs(title=" ",x="Conditional Average Treatment Effect",y="Density", fill="Treated")